from typing import Tuple
from enum import Enum
import threading
from queue import Empty, Queue
import os
import numpy as np
import torch
from PIL import Image
import onnxruntime
from pycocotools import mask


USE_SAM = True
USE_GroundingDINO = False
device = "cuda" if torch.cuda.is_available() else "cpu"


if USE_SAM:
    from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator

if USE_GroundingDINO:
    from groundingdino.util.inference import load_model, predict, annotate
    import groundingdino.datasets.transforms as T


if USE_SAM:
    sam_checkpoint = os.environ["HOME"] + "/Models/sam_hq_vit_h.pth"
    sam_onnx = os.environ["HOME"] + "/Models/sam_hq_vit_h.onnx"
    sam_model_type = "vit_h"
    sam_ort_session = onnxruntime.InferenceSession(
        sam_onnx, providers=onnxruntime.get_available_providers())
    sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    sam_predictor = SamPredictor(sam)
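
    # Note: the ONNX file is assumed to be an exported SAM-HQ mask *decoder*
    # that takes "image_embeddings" and "interm_embeddings" as inputs (see the
    # ort_inputs dicts below); the PyTorch checkpoint above still provides the
    # image encoder via SamPredictor.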

if USE_GroundingDINO:
    groundingdino_model = load_model(
        os.environ["HOME"] + "/Models/groundingdino_swint_ogc.py",
        os.environ["HOME"] + "/Models/groundingdino_swint_ogc.pth")
    groundingdino_box_threshold = 0.35
    groundingdino_text_threshold = 0.25


# Per-user caches keyed by uid: SAM embeddings from SAM_Backbone tasks and raw
# RGB images from CacheImage tasks.
g_sam_embeddings = dict()
g_cached_images = dict()


class AIGCTaskDef(Enum):
    SAM_Backbone          = 1
    SAM_Point             = 2
    SAM_Box_Points_Labels = 3
    CacheImage            = 4
    GroundingDINO         = 5
    GroundedSAM           = 6
    ChatGLM               = 7


class AIGCInputDef(Enum):
    Image             = 1
    Embedding         = 2
    Text              = 3
    Point             = 4
    Box_Points_Labels = 5


class AIGCEmbedding:
    """Cached SAM image embedding plus the metadata needed to rescale prompts."""

    def __init__(self):
        self.img_id = 0
        self.img_h = 0
        self.img_w = 0
        self.sam_embedding = None     # image encoder output, as a numpy array
        self.interm_embedding = None  # stacked ViT intermediate features (SAM-HQ)


def groundingdino_input(image: np.ndarray) -> Tuple[np.ndarray, torch.Tensor]:
    """Convert an RGB uint8 image array into the tensor format GroundingDINO expects."""
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_source = Image.fromarray(image)  # RGB image
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed
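
# Example (sketch): preparing a cached image for GroundingDINO and running a
# text-prompted detection, mirroring the call in AIGCTask.run() below. `rgb`
# is assumed to be an HxWx3 uint8 RGB array.
#
#   image_source, image = groundingdino_input(rgb)
#   boxes, logits, phrases = predict(
#       model=groundingdino_model, image=image, caption="a cat",
#       box_threshold=groundingdino_box_threshold,
#       text_threshold=groundingdino_text_threshold)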


def sam_backbone(image: np.ndarray) -> AIGCEmbedding:
    """Run the SAM image encoder once and package the embeddings for the ONNX decoder."""
    sam_predictor.set_image(image)  # expects an RGB uint8 image

    embedding = AIGCEmbedding()
    embedding.sam_embedding = sam_predictor.get_image_embedding().cpu().numpy()
    embedding.interm_embedding = torch.stack(sam_predictor.interm_features).cpu().numpy()
    embedding.img_h = image.shape[0]
    embedding.img_w = image.shape[1]
    return embedding
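
# Example (sketch): compute the embedding once per image and cache it, then
# answer any number of prompt queries against it. `rgb` is assumed to be an
# HxWx3 uint8 RGB array.
#
#   emb = sam_backbone(rgb)
#   counts = sam_point(emb.sam_embedding, emb.interm_embedding,
#                      0.5, 0.5, emb.img_w, emb.img_h)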


def sam_point(sam_embedding: np.ndarray, interm_embedding: np.ndarray,
              px: float, py: float, img_w: int, img_h: int) -> bytes:
    """Decode a mask for a single positive point prompt.

    `px` and `py` are normalized [0, 1] coordinates; the return value is the
    'counts' field of a COCO run-length encoding of the mask.
    """
    input_point = np.array([[px * img_w, py * img_h]])
    input_label = np.array([1])

    # The ONNX decoder expects a padding point with label -1 when no box prompt
    # is supplied, and coordinates transformed to the model's input resolution.
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    onnx_coord = sam_predictor.transform.apply_coords(onnx_coord, (img_h, img_w)).astype(np.float32)

    # No previous low-resolution mask is fed back into the decoder.
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": sam_embedding,
        "interm_embeddings": interm_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array((img_h, img_w), dtype=np.float32)
    }

    masks, _, _ = sam_ort_session.run(None, ort_inputs)
    masks = masks > sam_predictor.model.mask_threshold
    masks = masks.squeeze()

    # pycocotools expects a Fortran-ordered uint8 array for RLE encoding.
    enc_mask = mask.encode(np.asfortranarray(masks.astype(np.uint8)))
    return enc_mask['counts']
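
# Example (sketch): the value returned above is only the 'counts' field of a
# COCO run-length encoding. A client can rebuild the binary mask by pairing it
# with the original image size:
#
#   rle = {'size': [img_h, img_w], 'counts': counts}
#   binary_mask = mask.decode(rle)  # HxW uint8 array of 0/1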


def sam_box_points_labels(sam_embedding: np.ndarray,
                          interm_embedding: np.ndarray,
                          box: np.ndarray, pts: np.ndarray, labels: np.ndarray,
                          img_w: int, img_h: int) -> bytes:
    """Decode a mask for a combined box + points prompt.

    `box` is a normalized xyxy box (or empty for no box), `pts` are normalized
    point coordinates and `labels` their types (1 = foreground, 0 = background).
    Returns the 'counts' field of a COCO run-length encoding of the mask.
    """
    # Work on copies so the caller's normalized coordinates are not mutated.
    box = np.array(box, dtype=np.float32)
    pts = np.array(pts, dtype=np.float32)

    if len(box) == 4:
        # Boxes are encoded as two corner points with labels 2 (top-left)
        # and 3 (bottom-right), per the SAM ONNX prompt convention.
        box[0::2] *= img_w
        box[1::2] *= img_h
        onnx_box_coords = box.reshape(2, 2)
        onnx_box_labels = np.array([2, 3])
    else:
        onnx_box_coords = np.array([[0.0, 0.0]])
        onnx_box_labels = np.array([-1])

    pts[:, 0] *= img_w
    pts[:, 1] *= img_h
    onnx_coord = np.concatenate([pts, onnx_box_coords], axis=0)[None, :, :]
    onnx_label = np.concatenate([labels, onnx_box_labels], axis=0)[None, :].astype(
        np.float32)
    onnx_coord = sam_predictor.transform.apply_coords(onnx_coord, (img_h, img_w)).astype(
        np.float32)
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": sam_embedding,
        "interm_embeddings": interm_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array((img_h, img_w), dtype=np.float32)
    }

    masks, _, _ = sam_ort_session.run(None, ort_inputs)
    masks = masks > sam_predictor.model.mask_threshold
    masks = masks.squeeze()

    # pycocotools expects a Fortran-ordered uint8 array for RLE encoding.
    enc_mask = mask.encode(np.asfortranarray(masks.astype(np.uint8)))
    return enc_mask['counts']
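
# Example (sketch): segmenting inside a normalized box with one positive and
# one negative point (labels: 1 = foreground, 0 = background), reusing a
# cached embedding `emb`:
#
#   counts = sam_box_points_labels(
#       emb.sam_embedding, emb.interm_embedding,
#       np.array([0.1, 0.1, 0.9, 0.9]),      # normalized xyxy box
#       np.array([[0.5, 0.5], [0.2, 0.2]]),  # normalized points
#       np.array([1, 0]),
#       emb.img_w, emb.img_h)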


class AIGCTask:
    def __init__(self, uid: str, task: AIGCTaskDef) -> None:
        self.uid = uid      # user identity, also the key into the caches above
        self.req_id = None
        self.task = task
        self.image = None
        self.text = None
        self.point = None
        self.points = None
        self.labels = None
        self.box = None
        self.results = None

    def run(self) -> bool:
        done = True
        try:
            if   self.task == AIGCTaskDef.SAM_Backbone:
                assert USE_SAM
                assert self.image is not None
                # RGB(uint8) Image
                g_sam_embeddings[self.uid] = sam_backbone(self.image)
            elif self.task == AIGCTaskDef.SAM_Point:
                assert USE_SAM
                # assert self.embedding is not None
                assert self.point is not None
                if self.uid in g_sam_embeddings:
                    mask_enc = sam_point(
                        g_sam_embeddings[self.uid].sam_embedding,
                        g_sam_embeddings[self.uid].interm_embedding,
                        self.point[0],
                        self.point[1],
                        g_sam_embeddings[self.uid].img_w,
                        g_sam_embeddings[self.uid].img_h
                    )
                    self.results = mask_enc
                else:
                    done = False
            elif self.task == AIGCTaskDef.SAM_Box_Points_Labels:
                assert USE_SAM
                # assert self.embedding is not None
                assert self.box is not None
                assert self.points is not None
                assert self.labels is not None
                if self.uid in g_sam_embeddings:
                    mask_enc = sam_box_points_labels(
                        g_sam_embeddings[self.uid].sam_embedding,
                        g_sam_embeddings[self.uid].interm_embedding,
                        self.box,
                        self.points,
                        self.labels,
                        g_sam_embeddings[self.uid].img_w,
                        g_sam_embeddings[self.uid].img_h
                    )
                    self.results = mask_enc
                else:
                    done = False
            elif self.task == AIGCTaskDef.CacheImage:
                assert self.image is not None
                # RGB(uint8) Image
                g_cached_images[self.uid] = self.image
            elif self.task == AIGCTaskDef.GroundingDINO:
                assert USE_GroundingDINO
                assert self.text is not None
                # print(self.text)
                if self.uid in g_cached_images:
                    image_source, image = groundingdino_input(g_cached_images[self.uid])
                    boxes, logits, phrases = predict(
                        model=groundingdino_model,
                        image=image,
                        caption=self.text,
                        box_threshold=groundingdino_box_threshold,
                        text_threshold=groundingdino_text_threshold
                    )
                    self.results = {'boxes': boxes.cpu().numpy(), 'logits': logits.cpu().numpy(), 'phrases': phrases}
                    # print(self.results)
                else:
                    done = False
            else:
                # GroundedSAM and ChatGLM are declared in AIGCTaskDef but have
                # no handler here yet; fail them instead of silently succeeding.
                done = False

        except Exception as e:
            print(f"AIGCTask {self.task} failed for uid={self.uid}: {e}")
            done = False

        return done

    def set_req_id(self, id_: int) -> None:
        self.req_id = id_

    def set_image(self, image: np.ndarray) -> None:
        self.image = image

    def set_text(self, text: str) -> None:
        self.text = text

    def set_point(self, point: np.ndarray) -> None:
        self.point = point

    def set_box_points_labels(self, box: np.ndarray, points: np.ndarray, labels: np.ndarray) -> None:
        self.box = box
        self.points = points
        self.labels = labels


class AIGCSocket:
    """Transport stub: subclass and override these hooks to deliver results to clients."""

    def response_info(self, task: AIGCTask):
        pass

    def response_error(self, task: AIGCTask):
        pass


class AIGCTaskManager(threading.Thread):
    """Worker thread that runs queued tasks and reports results through the socket."""

    def __init__(self, socket: AIGCSocket):
        super().__init__()
        self.socket = socket
        self.task_queue = Queue()  # per-instance queue, not shared class state
        self.running = True

    def run(self):
        while self.running:
            try:
                # Block briefly instead of busy-polling so the loop can also
                # notice `running` being cleared by stop().
                task = self.task_queue.get(timeout=0.1)
            except Empty:
                continue
            if task.run():
                self.socket.response_info(task)
            else:
                self.socket.response_error(task)

    def stop(self):
        self.running = False

    def add_task(self, task: AIGCTask):
        self.task_queue.put(task)
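
# Example (sketch): wiring the manager to a socket implementation and
# submitting a task from another thread. AIGCSocket is the no-op stub above;
# a real deployment would subclass it to push results back to clients.
#
#   manager = AIGCTaskManager(AIGCSocket())
#   manager.start()
#   task = AIGCTask('user-1', AIGCTaskDef.CacheImage)
#   task.set_image(np.zeros((64, 64, 3), dtype=np.uint8))
#   manager.add_task(task)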


if __name__ == '__main__':
    # Minimal smoke test: a SAM_Backbone task with no image attached should
    # fail gracefully (run() returns False) instead of crashing.
    t = AIGCTask('qwer', AIGCTaskDef.SAM_Backbone)
    print(t.uid)
    print(t.task)
    print(t.box)
    print(t.run())
